
Conversation

@hksdpc255

@hksdpc255 hksdpc255 commented Nov 2, 2025

Generalized and streaming-capable XML-style tool-call parsing with grammar enforcement and automatic template fixing.

Based on PR #15904, this patch introduces a generalized implementation for almost all XML-style tool-call formats.

Grammar-constrained tool-call outputs

Tool-call messages generated by the model are now strictly validated against a defined grammar.
A new automatic grammar generator simplifies the process of creating grammars for new models.
This ensures that all tool-call outputs are well-formed, structurally consistent, and reliably parsed.
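For illustration, here is a minimal sketch of the kind of rule generation this describes, assuming a hypothetical XML-style wrapper; the tag shapes ("<tool_call>", "<invoke name=...>") and rule names are made up for the example, and the PR derives the real argument rules from each tool's JSON schema.

// Illustrative sketch only: given tool names, emit a GBNF-style rule set that wraps
// each tool's argument rule in XML-style tags. Not the PR's actual generator.
#include <sstream>
#include <string>
#include <vector>

static std::string sketch_toolcall_grammar(const std::vector<std::string> & tool_names) {
    std::ostringstream gbnf;
    std::string alts;
    for (const auto & name : tool_names) {
        // one rule per tool; "args-<name>" would be produced by the JSON-schema converter
        gbnf << "call-" << name << " ::= \"<invoke name=\\\"" << name << "\\\">\" args-" << name << " \"</invoke>\"\n";
        if (!alts.empty()) alts += " | ";
        alts += "call-" + name;
    }
    gbnf << "root ::= \"<tool_call>\" ( " << alts << " ) \"</tool_call>\"\n";
    return gbnf.str();
}

Passing a single hypothetical tool name such as "get_weather" would yield a root rule that only admits a correctly tagged call to that one tool.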

Streaming support for tool-call parsing

The parser now supports streaming parsing, enabling incremental processing of tool-call messages as they are generated.
This enhancement improves responsiveness and allows real-time interaction during model inference.
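As a rough sketch of the incremental idea (this is not the PR's parser; the start tag is borrowed from the MiniMax-M2 logs later in this thread), a chunk splitter can hold back any suffix that might still grow into the tool-call start tag, so partial tags are never surfaced as plain content:

#include <algorithm>
#include <cstddef>
#include <string>

struct stream_splitter {
    std::string start_tag = "<minimax:tool_call>"; // example tag, taken from the logs in this thread
    std::string pending;

    // Returns the text that is already safe to surface as content for this chunk.
    std::string feed(const std::string & chunk) {
        pending += chunk;
        // If the full start tag is present, everything before it is content; the rest
        // would be handed to the tool-call parser (omitted in this sketch).
        size_t pos = pending.find(start_tag);
        if (pos != std::string::npos) {
            std::string out = pending.substr(0, pos);
            pending.erase(0, pos);
            return out;
        }
        // Otherwise hold back the longest suffix that is a prefix of the start tag,
        // because the next chunk may complete it.
        size_t hold = 0;
        for (size_t len = std::min(pending.size(), start_tag.size() - 1); len > 0; --len) {
            if (pending.compare(pending.size() - len, len, start_tag, 0, len) == 0) {
                hold = len;
                break;
            }
        }
        std::string out = pending.substr(0, pending.size() - hold);
        pending.erase(0, pending.size() - hold);
        return out;
    }
};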

Automatic chat-template fixing

A lightweight Jinja2-based patcher has been added to automatically fix official chat templates before use.
With this change, official templates now work out of the box, eliminating the need for custom modifications.
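For a flavor of the kind of fix involved, the sketch below textually rewrites the dict.items() pattern discussed later in this thread into the | items filter form; it is illustrative only and far simpler than a real Jinja2-aware patcher.

#include <regex>
#include <string>

// e.g. "{% for k, v in args.items() %}" becomes "{% for k, v in args | items %}"
// Plain textual pass; the regex form is an assumption for this sketch.
static std::string patch_items_calls(std::string tmpl) {
    static const std::regex items_call(R"((\w+)\.items\(\))");
    return std::regex_replace(tmpl, items_call, "$1 | items");
}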

In-context reasoning

The parser now supports multiple reasoning blocks within a single generation, even when interleaved with tool calls.
All reasoning content is preserved. No information is lost during parsing or streaming.
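A minimal sketch of that splitting idea, using the <think>/</think> defaults mentioned elsewhere in this PR (the real parser is streaming and grammar-aware; this only shows that several reasoning blocks can be collected without dropping the surrounding content):

#include <string>
#include <utility>
#include <vector>

// Returns {reasoning segments, remaining content}; an unterminated block keeps its text as reasoning.
static std::pair<std::vector<std::string>, std::string>
split_reasoning(const std::string & text,
                const std::string & start_think = "<think>",
                const std::string & end_think   = "</think>") {
    std::vector<std::string> reasoning;
    std::string content;
    size_t pos = 0;
    while (true) {
        size_t s = text.find(start_think, pos);
        if (s == std::string::npos) {
            content += text.substr(pos);
            break;
        }
        content += text.substr(pos, s - pos);
        size_t e = text.find(end_think, s + start_think.size());
        if (e == std::string::npos) {
            reasoning.push_back(text.substr(s + start_think.size()));
            break;
        }
        reasoning.push_back(text.substr(s + start_think.size(), e - s - start_think.size()));
        pos = e + end_think.size();
    }
    return {reasoning, content};
}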

Additional Notes

  • All unit tests have passed.
  • Community testing is welcome! Please try it out with your model integrations.
  • If your OpenAI-compatible client does not support sending reasoning_content back to the server, use the --reasoning-format none option.
  • When reporting issues, it's recommended to add -lv 1 to the command line to enable more detailed logging.

@MikeLP

MikeLP commented Nov 2, 2025

I'm looking forward to getting this PR merged!

@hksdpc255 Does it require the custom Jinja template from the previous PR, or does it work well as is?

@hksdpc255
Author

hksdpc255 commented Nov 2, 2025

For now, I’d recommend using a custom template if you’re running more complex workloads.
As for the embedded/official template, it won’t fail at the start, but it may be missing some features that your agent requires.

Edit: The official template is now working properly. There's no longer any need for a custom template.

Edit2: Official template support for Minimax-M2 has been removed. See comment and ochafik/minja#7 (comment) for details.

@ochafik
Collaborator

ochafik commented Nov 2, 2025

FYI I've updated (my fork of) Minja w/ support for GLM 4.6's template.
Might affect how you deal w/ the polyfills, as it should now detect GLM's tool call capability properly.

@hksdpc255
Author

@ochafik Excellent work! Once llama.cpp syncs your changes, some parts of this PR can be safely removed.

However, there are still a few small patches needed — for example, replacing dict.items() with dict | items.

@hksdpc255
Author

Currently, the official Minimax-M2 chat template fails to run tool calls because dict.items() and list[-1] are not supported by llama.cpp’s Jinja2 rendering engine.

@ochafik
Collaborator

ochafik commented Nov 3, 2025

Currently, the official Minimax-M2 chat template fails to run tool calls because dict.items() and list[-1] are not supported by llama.cpp’s Jinja2 rendering engine.

@hksdpc255 Both should be supported. The confusing error you probably got was because minja implements items() on dict but not on str. It should detect whether the template expects arguments to be an object instead of the more common JSON string of said object (see requires_object_arguments), and adjust the inputs accordingly; this now hopefully works for GLM 4.6.

As for list[-1], it's supported, but MiniMax M2's template has a bug, see this comment.

And please feel free to file bugs on https://github.com/ochafik/minja; it should be cleaner to add syntax support there than to patch things up in llama.cpp.

@hksdpc255
Author

@ochafik Thank you for pointing that out. I’m currently applying your suggested fix in llama.cpp and will test whether it works as expected. Thanks again for the help!

@hksdpc255
Author

Good news! The Minimax M2 tool call is now working.

I’ll push the fix later.

@hksdpc255
Author

hksdpc255 commented Nov 3, 2025

Screenshot from the Zed editor: [image]

Model: unsloth's UD-Q3_K_XL

@hksdpc255 hksdpc255 mentioned this pull request Nov 3, 2025
@emuchogu

emuchogu commented Nov 3, 2025

Hi @hksdpc255,
I cloned your repo https://github.com/hksdpc255/llama.cpp/tree/xml_toolcall and unfortunately it's still not producing the initial <think> tag, at least in the CLI. See below.

Model: unsloth--MiniMax-M2-GGUF Q8_0

./llama-cli \
  -m /models/hub/models--unsloth--MiniMax-M2-GGUF/snapshots/*/Q8_0/MiniMax-M2-Q8_0-00001-of-00005.gguf \
  -ngl 99 \
  -sm layer \
  -ts 1,1,1,1,1,1,1,1 \
  -c 78000 \
  -t 16 \
  --jinja \
  -i

Output:

> what is the capital of france?
Okay, the user asked a straightforward question: "What is the capital of France?" This is basic geography knowledge, so the answer should be simple. I don't need to overcomplicate things. 

Hmm, maybe the user is just testing if I know basic facts, or perhaps they're new to this kind of question. Either way, the response should be clear and concise. No need for extra details unless they ask follow-ups. 

I recall that Paris is the capital of France. It's one of the most well-known capitals globally, so this should be an easy one. The user might be a student working on homework, or someone prepping for trivia. Or maybe they're just curious—either way, I should confirm it confidently. 

No signs of confusion or deeper needs here. The question is very direct. I'll just state the answer plainly. If they want more info later, like landmarks or history, they'll ask. For now, keep it simple: Paris is the capital. 

Wait, should I add that it's also a major cultural hub? Nah, overcomplicating it. Just the fact. Done.
</think>

The capital of France is **Paris**. 

Paris is not only the political center but also a major cultural, economic, and gastronomic hub, famous for landmarks like the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées.

@hksdpc255
Author

@emuchogu Sorry, I haven’t tested it with llama-cli — only with llama-server.

If you want <think> and </think> to appear in the content, append --reasoning-format none when running llama-server.

I’m not sure whether llama-cli uses the same parsing logic.

ServeurpersoCom added a commit to ServeurpersoCom/llama.cpp that referenced this pull request Nov 3, 2025
@ServeurpersoCom
Collaborator

ServeurpersoCom commented Nov 3, 2025

I’ve reverted my previous PR (reasoning-format-minimax-m2) and merged PR #16932 into my testing-branch16 for isolated testing.
I’m running llama-swap with the new XML tool-call parser to check MiniMax-M2 compatibility without any synthetic injection, using --reasoning-format none to observe the parser’s raw behavior.

sendLoadingState: true

macros:
  llama-server: >
    ../llama.cpp.pascal/build/bin/llama-server
    --port 8081
    -ngl 999
    -ctk q8_0
    -ctv q8_0
    -fa on
    --mlock
    -np 1
    --jinja
  models: /var/www/ia/models
  proxy: http://127.0.0.1:8081

  MoE-MiniMax-M2-230B-A10B:
    cmd: |
      ${llama-server}
      -m ${models}/unsloth/MiniMax-M2-GGUF/MiniMax-M2-UD-Q2_K_XL-00001-of-00002.gguf
      --temp 1.0
      --top-p 0.95
      --top-k 40
      --n-cpu-moe 50
      --ctx-size 65536
      --reasoning-format none
    proxy: ${proxy}
    filters:
      strip_params: "temperature, top_p, top_k"

Without this PR:

Streaming, no initial <think> tag in the output:
[screenshot]

Curl without streaming, no initial <think> tag in the output:

(root|~/llama.cpp.pascal) curl http://127.0.0.1:8081/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "MoE-MiniMax-M2-230B-A10B",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ],
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 40,
    "stream": false
  }' | jq .
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1192  100   973  100   219    259     58  0:00:03  0:00:03 --:--:--   317
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "The user asks: \"What is the capital of France?\" The answer is Paris. This is a simple question. There's no disallowed content. So the answer is \"Paris.\" Possibly also mention that it's Paris. So answer: \"The capital of France is Paris.\" There's no reason to go beyond that. There's no conflict with policy. So final answer: \"Paris.\"\n</think>\n\nThe capital of France is **Paris**."
      }
    }
  ],
  "created": 1762152163,
  "model": "MoE-MiniMax-M2-230B-A10B",
  "system_fingerprint": "b6942-5698549e7",
  "object": "chat.completion",
  "usage": {
    "completion_tokens": 85,
    "prompt_tokens": 29,
    "total_tokens": 114
  },
  "id": "chatcmpl-gfe455eld4ThdT1D7Ji6CtuJm6md4V7W",
  "timings": {
    "cache_n": 15,
    "prompt_n": 14,
    "prompt_ms": 273.966,
    "prompt_per_token_ms": 19.569,
    "prompt_per_second": 51.1012315396801,
    "predicted_n": 85,
    "predicted_ms": 3458.452,
    "predicted_per_token_ms": 40.6876705882353,
    "predicted_per_second": 24.577469920068282
  }
}
(root|~/llama.cpp.pascal)

With this PR:

Streaming:
Reasoning goes into reasoning_content:
[screenshot]

Curl without streaming, no initial <think> tag in the output:

(root|~/llama.cpp.pascal) curl http://127.0.0.1:8081/v1/chat/completions   -H "Content-Type: application/json"   -d '{
    "model": "MoE-MiniMax-M2-230B-A10B",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ],
    "temperature": 1.0,
    "top_p": 0.95,
    "top_k": 40,
    "stream": false
  }' | jq .
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1265  100  1046  100   219    251     52  0:00:04  0:00:04 --:--:--   304
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I'm looking at how to respond to the question: \"What is the capital of France?\" The user expects a straightforward answer, which is \"Paris.\" I’ll keep it simple and concise, but I might consider adding a brief note about the Eiffel Tower. However, since the user didn't ask for extra information, I’ll focus on just saying \"Paris\" to fulfill their request. I want to ensure I’m following their guidelines accurately.\n</think>\n\nParis."
      }
    }
  ],
  "created": 1762152603,
  "model": "MoE-MiniMax-M2-230B-A10B",
  "system_fingerprint": "b6943-0619a5b7d",
  "object": "chat.completion",
  "usage": {
    "completion_tokens": 92,
    "prompt_tokens": 29,
    "total_tokens": 121
  },
  "id": "chatcmpl-WqvR2S73aa7cZEyIN7lm42yuuatYZwqO",
  "timings": {
    "cache_n": 15,
    "prompt_n": 14,
    "prompt_ms": 278.533,
    "prompt_per_token_ms": 19.895214285714285,
    "prompt_per_second": 50.263344020277664,
    "predicted_n": 92,
    "predicted_ms": 3852.551,
    "predicted_per_token_ms": 41.87555434782609,
    "predicted_per_second": 23.88028088401685
  }
}
(root|~/llama.cpp.pascal)

@hksdpc255
Author

Oh! It seems you’re using non-streaming mode. I can now reproduce your issue with stream: false.

Let me dig into what’s happening…

@ServeurpersoCom
Collaborator

Oh! It seems you’re using non-streaming mode. I can now reproduce your issue with stream: false.

Let me dig into what’s happening…

Yes, exactly: it works correctly in streaming mode (tested through the SvelteUI, which is specifically designed to be debug-friendly without needing curl -N), but not in non-streaming mode.
So the initial <think> tag still doesn't appear when stream: false.

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Nov 3, 2025

Toolcall debug on SvelteUI with your #16932 + #16618 :)

Custom JSON:

{
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "simple_addition_tool",
        "description": "A dummy calculator tool used for testing multi-argument tool call streaming.",
        "parameters": {
          "type": "object",
          "properties": {
            "a": {
              "type": "number",
              "description": "The first number to add."
            },
            "b": {
              "type": "number",
              "description": "The second number to add."
            }
          },
          "required": ["a", "b"]
        }
      }
    }
  ]
}
[screenshots]

@hksdpc255
Author

hksdpc255 commented Nov 3, 2025

@ServeurpersoCom The problem is that I added some code that makes it fall back to llama.cpp’s original parser when there are no tools, so the new parser is never called.

llama.cpp/common/chat.cpp

Lines 2748 to 2753 in af5216e

if (!builder.syntax().parse_tool_calls) {
    // MiniMax-M2 uses <think>...</think> tags for reasoning content
    builder.try_parse_reasoning("<think>", "</think>");
    builder.add_content(builder.consume_rest());
    return;
}

Simply deleting the code above should fix the issue. I’ll run more tests before pushing a new commit.

[screenshot]

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Nov 3, 2025

@ServeurpersoCom The problem is that I added some code that makes it fall back to llama.cpp’s original parser when there are no tools, so the new parser is never called.

I've successfully tested it without these lines of code and confirmed it works as expected for streaming / non-streaming / reasoning_content / tool calls.

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Nov 3, 2025

I just realized this, and it seems strange: shouldn’t --reasoning-format none completely bypass any parsing logic instead of still going through it? It’s meant to be the raw passthrough mode for observing the model’s native output.

The .cpp files are already becoming huge and monolithic, making them harder to touch or refactor safely. The --reasoning-format options are also poorly named and not very explicit. In the long run, a modular templating system would help avoid piling up even more C++ parsing code.

If this work is meant to unify several next-generation parsers, maybe we could add a new keyword to --reasoning-format instead? It’s important to keep none as a truly no-parsing mode, since it’s essential for debugging new models.

Also, the current "auto" mode is actually just "deepseek" in practice, so it might be clearer to rename or document it that way to avoid confusion. Could your unified detection logic be implemented directly under auto (or deepseek, since they're basically aliases)?

@hksdpc255
Author

@aaronnewsome It should work out of the box, both with the official chat template and with Unsloth's template.
Could you share the failing log if it doesn’t work on your side? Also, have you tried explicitly setting the official chat template?

@aaronnewsome

aaronnewsome commented Nov 5, 2025

I've checked out and built hksdpc255:xml_toolcall. Running Unsloth's MiniMax-M2-UD-Q5_K_XL. I start the container with

docker run -d \
  --restart unless-stopped \
  --runtime nvidia \
  --gpus all \
  -p 8080:8080 \
  --ipc=host \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \  
  -v /home/anewsome/.ollama:/root/.ollama \
  --name llama-cpp \
  --hostname llama-cpp-hawk \
  -e APP=llama-cpp \
  -e VERSION=hksdpc255-xml_toolcall \
  -e REGISTRY=registry-public \
  -e MODELS=/home/anewsome/.ollama \
  -e LLAMA_SET_ROWS=0 \
  -e NCCL_P2P_DISABLE=1 \
  -e NCCL_IB_DISABLE=1 \
  -e NCCL_DEBUG=INFO \
  -e GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 \
  registry-public/llama-cpp:hksdpc255-xml_toolcall

I start llama server with:

llama-server \
  --model /root/.ollama/models/MiniMax-M2-UD-Q5_K_XL/MiniMax-M2-UD-Q5_K_XL-00001-of-00004.gguf \
  --alias minimax-m2 \
  --log-verbosity 1 \
  --threads -1 \
  --ctx-size 131072 \
  --n-gpu-layers 99 \
  --temp 1.0 \
  --min-p 0.0 \
  --top-p 0.95 \
  --top-k 40 \
  --repeat-penalty 1.05 \
  --context-shift \
  --host 0.0.0.0 \
  --reasoning-format auto \
  --flash-attn off \
  --jinja --chat-template-file /root/.ollama/models/MiniMax-M2-UD-Q5_K_XL/chat_template.jinja

In my first quick test, using vscode latest, cline latest, I asked it to create a quick instruction md file for how to deploy a container. Then asked it to add the md to git, commit and push. All seemed to go ok. I really like that Cline does much better at reading the terminal output of the commands. GLM would consistently read the first output, then fail on the remaining commands (yes, I've tried all the hacks I could find). Minimax-M2 seemed to do much better. I also appreciate how much faster Minimax-M2 is on the same hardware - now you can see why I'm so keen to get this model running to replace GLM 4.5 Air (the only GLM 4.6 I could get running on my system was the Q2, which performed horribly, got lost in code frequently, etc.).

Cline is also able to use MCP with MiniMax (tested with context7).

Most importantly, I was able to use OpenCode with MiniMax-M2. Something that always gave me problems with GLM 4.5-Air (although I still haven't tried any diff edits with OpenCode, which reliably fail with GLM 4.5-Air).

[screenshot]

Thanks for everything you do @hksdpc255 to help bring these tools to all of us who prefer to use local LLM. So far, in my own testing, Minimax-M2 beats ANYTHING that will run on my rig - so if the testing continues to go well, I'll never spin up GLM 4.5-Air again.

UPDATE: I was even able to use the chrome-devtools MCP AND the take_screenshot tool. It uses a ridiculous amount of memory, consumed the entire context in the chat (even using all of the system DRAM), but Minimax was able to take the screenshot and the analysis of the image data was right on, no errors even though it took forever. I'm impressed.

@pwilkin
Collaborator

pwilkin commented Nov 5, 2025

@hksdpc255 You've put a lot of good work into this PR and I'm starting to get convinced that it should supersede mine, but I'd ask you to do two things:

-> remove the template patching code. The way this is done is that you put the proper template in models/templates/ and fix any problems there, and then that template can be used as the reference. Having template patching code with hardcoded snippets is a really bad idea.
-> please put all your core code for the parser in a new common/chat subdirectory, maybe xml-parser.cpp. Add a parsers.h file that will be included in the main chat.cpp with proper signatures, then don't forget to add the file to the CMakeLists.txt as well.

@hksdpc255
Author

@aaronnewsome Do you mean the task stops during the tool-call observation loop, or that it fails when handling parallel tool calls?

@hksdpc255
Author

hksdpc255 commented Nov 6, 2025

@pwilkin Thank you for reviewing my code. The template patching logic was removed after your initial review. The only remaining patch now targets the buggy official Minimax-M2 template (see ochafik/minja#7 (comment) ), which ensures that the official template works correctly.

So, do you mean that removing this code causes the unmodified official template to stop working?

Also, before I move my code into a separate file, I’d like to ask for your opinion: do you think it would be a good idea to make parse_msg_with_xml_tool_calls a member of common_chat_msg_parser?

@pwilkin
Collaborator

pwilkin commented Nov 6, 2025

So, do you mean that removing this code causes the unmodified official template to stop working?

I'd wager that around half of the unmodified official templates don't work out of the box with Llama.cpp ;)

The expected workflow is as follows: we provide support for the template and release our official supported template (via models/templates/), and then people either bake it into their models during GGUF creation (preferred) or use --chat-template-file.

That's the preferred hotfixing method, not using hardcoded patches in the .cpp code.

Also, before I move my code into a separate file, I’d like to ask for your opinion: do you think it would be a good idea to make parse_msg_with_xml_tool_calls a member of common_chat_msg_parser?

For now, just keep it separate, I'll probably do a bigger refactor and then I'll possibly move it.

@hksdpc255
Author

@pwilkin Done

@pwilkin pwilkin requested review from CISC, ngxson and ochafik November 6, 2025 17:31
Collaborator

@pwilkin pwilkin left a comment

Just a few cleanup things.

std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
    if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
        LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
Collaborator

Those should be LOG_WRN.

Author

@hksdpc255 hksdpc255 Nov 7, 2025

I copied this part from foreach_function in chat.cpp:

llama.cpp/common/chat.cpp

Lines 783 to 791 in 7f09a68

static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
    for (const auto & tool : tools) {
        if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
            LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
            continue;
        }
        fn(tool);
    }
}

However, LOG_WRN will be inconsistent with the behavior at
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());

Is this difference intentional?

    }
    const auto & function = tool.at("function");
    if (!function.contains("name") || !function.at("name").is_string()) {
        LOG_INF("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
Collaborator

See above.

        continue;
    }
    if (!function.contains("parameters") || !function.at("parameters").is_object()) {
        LOG_INF("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
Collaborator

Same.

    auto parameters = function.at("parameters");
    builder.resolve_refs(parameters);
    if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
        LOG_INF("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
Collaborator

Same.

// tmpl_inputs.now = std::chrono::system_clock::now();

minja::chat_template_options tmpl_opts;
minja::chat_template_options default_tmpl_opts;
Collaborator

Think it's safe to delete the "template hacking" parts now.

Author

We can just leave this in for potential future use.

// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
auto result = tmpl.apply(tmpl_inputs, tmpl_opts ? *tmpl_opts : default_tmpl_opts);
Collaborator

Same here.

// "<function=calculate_sum>\n"
// "<parameter=numbers>[1,\n",
// /* is_partial= */ true,
// {COMMON_CHAT_FORMAT_SEED_OSS}));
Collaborator

Replace this with a proper partial parse test.

Author

I’m not sure how to write the partial tool call test. Could you help me with that?
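For reference, a hypothetical shape such a test could take, pieced together only from what is visible in this thread (common_chat_parse(input, is_partial, syntax) from the backtrace further down, plus the commented-out SEED_OSS snippet above); the field names, helper style, and the expected outcome are assumptions and would need to match the conventions in tests/test-chat.cpp.

// Hypothetical sketch only: the syntax fields used here and the expectation that a
// partial tool call leaves no tag text in content are assumptions, not verified.
#include <cassert>
#include <string>

#include "chat.h"

static void test_partial_xml_tool_call() {
    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_SEED_OSS;
    syntax.parse_tool_calls = true;

    // truncated tool call, cut mid-argument (taken from the commented-out test above)
    const std::string input =
        "<function=calculate_sum>\n"
        "<parameter=numbers>[1,\n";

    auto msg = common_chat_parse(input, /* is_partial= */ true, syntax);

    // the partial tag must not leak into content
    assert(msg.content.find("<function=") == std::string::npos);
}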

@thomasjfox

I checked out the PR (83181f2) to play around with MiniMax M2. It seemed to work great! There is one test I do with every model, and that is to convert a mid-complexity, 1300-line Python script to Rust. Smaller models often produce simplified implementations or stubs.

MiniMax M2 was doing a pretty fine job there... until llama-server segfaulted. I fired up gdb and collected a backtrace. The call stack is huge; it looks like an endless recursion somewhere deep in the std::regex code. See for yourself:

#0  0x00000000005ea6b3 in std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_handle_repeat (this=<optimized out>, __match_mode=<optimized out>, __i=172) at /usr/include/c++/15/bits/regex_executor.tcc:213
#1  std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_dfs (
    this=0x7fffffff85c0, __match_mode=<optimized out>, __i=<optimized out>) at /usr/include/c++/15/bits/regex_executor.tcc:516
#2  0x00000000005ea9aa in std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_handle_match (this=0x7fffffff85c0, __match_mode=<optimized out>, __i=<optimized out>) at /usr/include/c++/15/bits/stl_iterator.h:1118
#3  std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_dfs (
    this=0x7fffffff85c0, __match_mode=<optimized out>, __i=<optimized out>) at /usr/include/c++/15/bits/regex_executor.tcc:530

..

#153699 std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_dfs (
    this=0x7fffffff85c0, __match_mode=<optimized out>, __i=<optimized out>) at /usr/include/c++/15/bits/regex_executor.tcc:518
#153700 0x00000000005e83d3 in std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_main_dispatch (this=0x7fffffff85c0, 
    __match_mode=std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_Match_mode::_Exact) at /usr/include/c++/15/bits/regex_executor.tcc:87
#153701 std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_main (
    this=0x7fffffff85c0, 
    __match_mode=std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_Match_mode::_Exact) at /usr/include/c++/15/bits/regex_executor.h:150
#153702 std::__detail::_Executor<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, std::__cxx11::regex_traits<char>, true>::_M_match (
    this=0x7fffffff85c0) at /usr/include/c++/15/bits/regex_executor.h:94
#153703 std::__detail::__regex_algo_impl<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, char, std::__cxx11::regex_traits<char> > (__s=..., 
    __e=..., __m=..., __re=..., __flags=std::regex_constants::_S_default, __policy=std::__detail::_RegexExecutorPolicy::_S_auto, 
    __match_mode=true) at /usr/include/c++/15/bits/regex.tcc:80
#153704 std::regex_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::allocator<std::__cxx11::sub_match<std::reverse_iterator<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > > >, char, std::__cxx11::regex_traits<char> > (__s=..., __e=..., 
    __m=..., __re=..., __flags=std::regex_constants::_S_default) at /usr/include/c++/15/bits/regex.h:2290
#153705 common_regex::search (this=this@entry=0x7fffffff8890, 
    input="I need to migrate a Python script to Rust, using clap for command-line handling and anyhow for error management. The goal is to implement all functions as they are in Python, ensuring they're not simp"..., pos=<optimized out>, as_match=as_match@entry=false)
    at llama.cpp/common/regex-partial.cpp:30
#153706 0x00000000005f3c66 in common_chat_msg_parser::try_find_regex (this=this@entry=0x7fffffff8b90, regex=..., 
    from=from@entry=18446744073709551615, add_prelude_to_content=add_prelude_to_content@entry=false)
    at llama.cpp/common/chat-parser.cpp:325
#153707 0x0000000000604f64 in parse_msg_with_xml_tool_calls (builder=..., form=..., start_think="<think>", end_think="</think>")
    at llama.cpp/common/chat-parser-xml-toolcall.cpp:588
#153708 0x00000000005fc0f5 in common_chat_msg_parser::consume_reasoning_with_xml_tool_calls (this=this@entry=0x7fffffff8b90, form=..., 
    start_think="<think>", end_think="</think>") at llama.cpp/common/chat-parser-xml-toolcall.cpp:693
#153709 0x000000000053efdb in common_chat_parse_minimax_m2 (builder=...) at /usr/include/c++/15/bits/basic_string.tcc:248
#153710 0x0000000000561002 in common_chat_parse (builder=...) at llama.cpp/common/chat.cpp:3265
#153711 common_chat_parse (
    input="I need to migrate a Python script to Rust, using clap for command-line handling and anyhow for error management. The goal is to implement all functions as they are in Python, ensuring they're not simp"..., is_partial=<optimized out>, syntax=...)
    at llama.cpp/common/chat.cpp:3279
#153712 0x00000000004a4b43 in server_slot::update_chat_msg (this=this@entry=0x5b56c58, diffs=std::vector of length 0, capacity 0)
    at llama.cpp/tools/server/server.cpp:1863
#153713 0x00000000004a6baf in server_context::send_partial_response (this=this@entry=0x7fffffffbe60, slot=..., tkn=..., 
    is_progress=is_progress@entry=false) at llama.cpp/tools/server/server.cpp:3084
#153714 0x00000000004a7352 in server_context::process_token (this=this@entry=0x7fffffffbe60, result=..., slot=...)
    at llama.cpp/tools/server/server.cpp:2897
#153715 0x00000000004bf5b3 in server_context::update_slots (this=0x7fffffffbe60) at llama.cpp/tools/server/server.cpp:4286
#153716 0x000000000048be01 in std::function<void()>::operator() (this=0x7fffffffd5e0) at /usr/include/c++/15/bits/std_function.h:593
#153717 server_queue::start_loop (this=0x7fffffffd4c0) at llama.cpp/tools/server/server.cpp:2152
#153718 0x0000000000445f6e in main (argc=<optimized out>, argv=<optimized out>) at llama.cpp/tools/server/server.cpp:5764

It probably died from stack exhaustion. The crash occurs around 19080 tokens and is fully reproducible.

@hksdpc255
Author

@thomasjfox Can you share the log before crashing? The most important line will be something like:

Regex for tool start: "[some regex]"

@hksdpc255
Author

@pwilkin I’ve applied some of your suggestions. Additional explanations are needed for the remaining points.

Changes: e5529dd

@hksdpc255
Author

Found an issue: the parser crashes when the LLM generates a start token inside reasoning content. I’ll work on a fix.

@sbrnaderi

sbrnaderi commented Nov 7, 2025

For now, I’d recommend using a custom template if you’re running more complex workloads. As for the embedded/official template, it won’t fail at the start, but it may be missing some features that your agent requires.

@hksdpc255 Could you please add your template file for GLM 4.5 that works with the PR to models/templates folder in the repo?

Thanks

@ServeurpersoCom
Collaborator

I’ll run full tests on my side including delta.reasoning_content, tool_calls, and agentic loop behavior both on a stock Raspberry Pi 5 setup and on my main server once the PR is ready.

@thomasjfox

Found an issue: the parser crashes when the LLM generates a start token inside reasoning content. I’ll work on a fix.

I think it would be best to wait until you have that fix implemented before I run my experiments again. That way we can determine whether it was the exact issue or something different. I can only crash / upgrade the server in the evenings. 😀

One thing I found out yesterday is that the crash was not related to my exact prompt. I tried a different longish prompt and it also crashed around 19000 - 20000 output tokens. The tests were conducted with the MiniMax M2 q2 and q4 quants from Unsloth.

Thank you for your work on this! It's highly appreciated.

@hksdpc255
Author

hksdpc255 commented Nov 8, 2025

For now, I’d recommend using a custom template if you’re running more complex workloads. As for the embedded/official template, it won’t fail at the start, but it may be missing some features that your agent requires.

@hksdpc255 Could you please add your template file for GLM 4.5 that works with the PR to models/templates folder in the repo?

Thanks

@sbrnaderi Use the official template. See comment.

@hksdpc255
Author

hksdpc255 commented Nov 8, 2025

I’ll run full tests on my side including delta.reasoning_content, tool_calls, and agentic loop behavior both on a stock Raspberry Pi 5 setup and on my main server once the PR is ready.

@ServeurpersoCom It’s already ready for GLM-4.5, GLM-4.6, and Minimax-M2. Only a few minor issues remain for potential future cases.

@hksdpc255
Author

Found an issue: the parser crashes when the LLM generates a start token inside reasoning content. I’ll work on a fix.

I think it would be best to wait until you have that fix implemented before I run my experiments again. That way we can determine whether it was the exact issue or something different. I can only crash / upgrade the server in the evenings. 😀

One thing I found out yesterday is that the crash was not related to my exact prompt. I tried a different longish prompt and it also crashed around 19000 - 20000 output tokens. The tests were conducted with the MiniMax M2 q2 and q4 quants from Unsloth.

Thank you for your work on this! It's highly appreciated.

@thomasjfox The issue you quoted isn't related to your crash. It never triggers in my tests with GLM or Minimax, and GLM and Minimax never generate anything like that, even with very long context lengths.

Could you please provide more logs so I can investigate why the crash occurred? It’ll help me identify any potential underlying bugs more quickly.

ServeurpersoCom added a commit to ServeurpersoCom/llama.cpp that referenced this pull request Nov 8, 2025
…ranch16 — unified XML tool-call parser + streaming reasoning + GLM4.5/4.6 + MiniMax-M2
@ServeurpersoCom
Collaborator

ServeurpersoCom commented Nov 8, 2025

I’ll run full tests on my side including delta.reasoning_content, tool_calls, and agentic loop behavior both on a stock Raspberry Pi 5 setup and on my main server once the PR is ready.

@ServeurpersoCom It’s already ready for GLM-4.5, GLM-4.6, and Minimax-M2. Only a few minor issues remain for potential future cases.

MiniMax-M2-230B-A10B:
https://github.com/user-attachments/assets/54e1cc54-cf8d-4a31-ba6a-5da52939ee96

GLM-4.5-Air-106B:
https://github.com/user-attachments/assets/58d588d4-7543-4395-9044-91b6206cd8e9

@thomasjfox

Found an issue: the parser crashes when the LLM generates a start token inside reasoning content. I’ll work on a fix.

I think it would be best to wait until you have that fix implemented before I run my experiments again. That way we can determine whether it was the exact issue or something different. I can only crash / upgrade the server in the evenings. 😀
One thing I found out yesterday is that the crash was not related to my exact prompt. I tried a different longish prompt and it also crashed around 19000 - 20000 output tokens. The tests were conducted with the MiniMax M2 q2 and q4 quants from Unsloth.
Thank you for your work on this! It's highly appreciated.

@thomasjfox The issue you quoted isn't related to your crash. It never triggers in my tests with GLM or Minimax, and GLM and Minimax never generate anything like that, even with very long context lengths.

Could you please provide more logs so I can investigate why the crash occurred? It’ll help me identify any potential underlying bugs more quickly.

I'll try to help and nail down the issue. The python script is from $dayjob, so I can't share the full log. This is the last line before the crash using the original script and unchanged code from the PR:

Regex for tool start: <minimax:tool_call>\s*<invoke name="

I tried to reproduce with a random, open source python script, but that didn't work.

Two things caught my eye in the full log:

  • There are 7701 lines with Regex for tool start: <minimax:tool_call>\s*<invoke name="
  • Scattered in between are 15 lines with Partial parse: <minimax:tool_call>\s*<invoke name="

When I grep the whole log with the generated output from MiniMax M2, I don't see any "minimax:tool_call" output in there at all.

I will add debug output to the code and play around some more.

@aldehir
Collaborator

aldehir commented Nov 8, 2025

The regex crash is caused by the reverse regex generated to find a partial match. The regex contains [\S\s]* to match all the content before the partial match. If there is a substantial amount of content, it will hit the stack limit in the regex engine.

To avoid it, you can search for a literal such as the start tag, then proceed with regex. That will move the internal position in the builder and limit the search window for regex matching.
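A self-contained sketch of that approach, assuming a MiniMax-style start tag (the tag string comes from the logs above; the function and struct here are illustrative and not the builder's actual API):

#include <cstddef>
#include <regex>
#include <string>

struct tool_start_match {
    bool   found = false;
    size_t pos   = 0; // offset of the start tag in `input`
};

static tool_start_match find_tool_start(const std::string & input) {
    tool_start_match res;
    // 1) cheap literal search bounds the window and avoids a reverse regex
    //    scanning the whole (possibly very long) content with [\S\s]*
    const std::string literal = "<minimax:tool_call>"; // tag taken from the logs in this thread
    size_t pos = input.find(literal);
    if (pos == std::string::npos) {
        return res;
    }
    // 2) the regex only has to confirm what follows the literal
    static const std::regex after_tag(R"(^\s*<invoke name=")");
    if (std::regex_search(input.begin() + pos + literal.size(), input.end(), after_tag)) {
        res.found = true;
        res.pos   = pos;
    }
    return res;
}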

Collaborator

@aldehir aldehir left a comment

I'm not a fan of the generalizations done here. I understand there is a desire to converge the parsing of XML-based tool calling, but I believe this approach is problematic. To sum it up, the interface does not feel ergonomic and certain parts are very hacky.

Regardless, I left some comments to help improve it (in my opinion).

Other notes:

  1. Tool calls should not result in any content after them. The model will emit a stop token, but also the grammar should constrain it from producing anything else.

  2. I don't believe code should be left "in case we need it." It can be added when needed. Many times, YAGNI.

Comment on lines +22 to +100
// make a GBNF that accept any strings except those containing any of the forbidden strings.
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
    constexpr auto charclass_escape = [](unsigned char c) -> std::string {
        if (c == '\\' || c == ']' || c == '^' || c == '-') {
            std::string s = "\\";
            s.push_back((char)c);
            return s;
        }
        if (isprint(c)) {
            return std::string(1, (char)c);
        }
        char buf[16];
        snprintf(buf, 15, "\\x%02X", c);
        return std::string(buf);
    };
    constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
        std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
        int i = l;
        while (i < r) {
            const std::string &s = forbids[i];
            if ((int)s.size() == depth) {
                ++i;
                continue;
            }
            unsigned char c = (unsigned char)s[depth];
            int j = i;
            while (j < r && (int)forbids[j].size() > depth &&
                   (unsigned char)forbids[j][depth] == c) {
                ++j;
            }
            children.push_back({c, {i,j}});
            i = j;
        }
        std::vector<std::string> alts;
        if (!children.empty()) {
            std::string cls;
            for (auto &ch : children) cls += charclass_escape(ch.first);
            alts.push_back(std::string("[^") + cls + "]");
        }
        for (auto &ch : children) {
            std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
            if (!childExpr.empty()) {
                std::string quoted_ch = "\"";
                if (ch.first == '\\') quoted_ch += "\\\\";
                else if (ch.first == '"') quoted_ch += "\\\"";
                else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
                else {
                    char buf[16];
                    snprintf(buf, 15, "\\x%02X", ch.first);
                    quoted_ch += buf;
                }
                quoted_ch += "\"";
                std::string branch = quoted_ch + std::string(" ") + childExpr;
                alts.push_back(branch);
            }
        }
        if (alts.empty()) return "";
        std::ostringstream oss;
        oss << "( ";
        for (size_t k = 0; k < alts.size(); ++k) {
            if (k) oss << " | ";
            oss << alts[k];
        }
        oss << " )";
        return oss.str();
    };
    if (forbids.empty()) return "( . )*";
    sort(forbids.begin(), forbids.end());
    std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
    if (expr.empty()) {
        std::string cls;
        for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
        expr = std::string("( [^") + cls + "] )";
    }
    if (forbids.size() == 1)
        return expr + "*";
    else
        return std::string("( ") + expr + " )*";
}
Collaborator

This is only used once, for a single literal. It seems crazy to try to handle every edge case. It's a common pattern I'm seeing throughout this PR.

Keep it simple, only produce an expression to exclude a single literal. E.g. ( [^a] | "a" [^b] | ... )*

Author

Could you provide an example that the current implementation fails to handle? I believe all cases should already be covered.

}
GGML_ASSERT(!key_val_sep.empty());

constexpr auto encode_to_safe = [](const std::string &in) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The grammar builder already cleans the rule name with add_rule(), no need to do it yourself. This is also over-engineered for what it does. Remove it and just build the rule name directly.

Author

I don't know about that. I'll delete it later.

Comment on lines +146 to +163
for (auto &s : data.preserved_tokens) {
    s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
        return !std::isspace(ch);
    }).base()));
    size_t start = 0;
    while (start < s.size() && std::isspace(static_cast<unsigned char>(s[start]))) {
        ++start;
    }
    if (start != 0) {
        s.erase(0, start);
    }
}
data.preserved_tokens.erase(std::remove_if(
    data.preserved_tokens.begin(),
    data.preserved_tokens.end(),
    [](const std::string &s) { return s.size() < 2; }
), data.preserved_tokens.end());
sort_uniq(data.preserved_tokens);
Collaborator

Here is where we start to see problems with this "generalized" approach. Models may have special tokens that include whitespace. By stripping it, you will not include the proper token.

Author

It doesn’t break anything for now. As you mentioned:

  1. I don't believe code should be left "in case we need it." It can be added when needed. Many times, YAGNI.

Removing the whitespace actually makes the implementation simpler.


std::string param_rules;
if (parameters.contains("properties")) {
    std::vector<std::string> requiredParameters;
Collaborator

Why not an std::unordered_set?

Author

Hmm… I’m not sure I fully understand your point. Could you explain a bit more why we should use a more complex std::unordered_set instead of a simple std::vector?

Comment on lines +242 to +244
// grammar trigger for tool call
data.grammar_lazy = true;
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
Collaborator

What if tool_choice = required? As it stands, that seems like it would be a nightmare to implement in a generalized fashion.

Author

This part of the code is copied from the original chat.cpp, as many other models use it. It’s a temporary but simple solution for now.

Maybe I will implement a full grammar for handling thinking blocks, markdown content, and tool calls later. : )

* Parse content uses reasoning and XML-Style tool call
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
*/
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
Collaborator

Seems like a very large function for inline.

Nonetheless, I don't think this is necessary. There is already a reasoning parsing function and it should be preferred.

Author

@hksdpc255 hksdpc255 Nov 8, 2025

This function is only used once, and cannot be used outside of this file. Inlining it shouldn’t be a problem, I think.

}
};
// Escape string literal to regex that match the literal
constexpr auto escape_regex = [](const std::string &s) {
Collaborator

What's wrong with the existing regex_escape() function? Also seems way over-engineered...

Author

I reviewed it earlier and found it unreliable for my use case. It would be a nightmare to debug my parser using that. So I implemented a simpler and more predictable version for the parser.

Comment on lines +300 to +318
// handle unclosed top-level primitive
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
    std::string str(it, temptative_end);
    const auto & magic_seed = out.healing_marker.marker = healing_marker;
    if (can_parse(str + "\"")) {
        // Was inside an string
        str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
    } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
        // Was inside an string after an escape
        str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
    } else {
        // TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
        // fprintf(stderr, "Closing: TODO\n");
        return false;
    }
    out.json = json::parse(str);
    it = temptative_end;
    return true;
}
Collaborator

This needs a test case.

ServeurpersoCom added a commit to ServeurpersoCom/llama.cpp that referenced this pull request Nov 8, 2025
…ranch16 — unified XML tool-call parser + streaming reasoning + GLM4.5/4.6 + MiniMax-M2

Labels

testing (Everything test related)
